
Fix subgroup optimizer metadata inconsistency #7820

Draft
st-bang97 wants to merge 5 commits into deepspeedai:master from st-bang97:stbang/async_update
Conversation

@st-bang97
Contributor

Description

This PR addresses Issue #7819.

When using ZeRO Stage-3 with a CPU-offloaded optimizer (CPUAdam) and 2+ subgroups, `ds_adam_step` can be invoked multiple times within a single global optimizer step. Before this fix, the internal bias-correction state (e.g., `_betta2_t_bias_correction2`) could become inconsistent across subgroup invocations within the same step, leading to subgroup-wise optimizer state divergence.


Solution

Make IncrementStep() step-consistent under repeated calls in the same global step.

Change

Update IncrementStep() to only advance or recompute state when step != _step, preventing subgroup-to-subgroup drift inside a single step.

```cpp
inline void IncrementStep(size_t step, float beta1, float beta2)
{
    if (beta1 != _betta1 || beta2 != _betta2) {
        // Betas changed: recompute bias-correction terms from scratch.
        _step = step;
        _betta1 = beta1;
        _betta2 = beta2;
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
    } else {
        if (step != _step) {
            _step++;
            if (_step != step) {
                // Non-sequential step: recompute directly.
                _betta1_t = std::pow(_betta1, step);
                _betta2_t = std::pow(_betta2, step);
                _step = step;
            } else {
                // Sequential step: cheap incremental update.
                _betta1_t *= _betta1;
                _betta2_t *= _betta2;
            }
        }
        // step == _step: no-op, so repeated same-step calls
        // (one per subgroup) leave the state untouched.
    }
}
```
1. Reproduction (before fix): screenshots "2026-01-27 194552" and "2026-01-27 194716"
2. Verification (after fix): screenshots "2026-01-27 200422" and "2026-01-27 200410"

@st-bang97 st-bang97 requested a review from tjruwase as a code owner January 27, 2026 11:19
@tohtana
Collaborator

tohtana commented Feb 18, 2026

Hi @st-bang97,
Thank you for the contribution!
Sorry, I just noticed you've already addressed #7819. Please disregard my comment.

Overall this PR looks good to me. Here are some suggestions:

@st-bang97
Contributor Author

Thanks for the review and feedback. I’m currently finalizing a paper submission and preparing for reviews, so I won’t be able to address this immediately. I’ll take a closer look and update the PR as soon as I have time.

@st-bang97 st-bang97 marked this pull request as draft April 21, 2026 05:36
delock pushed a commit that referenced this pull request Apr 23, 2026
Fix CPUAdam same-step subgroup drift in ZeRO-3 (#7819)

This PR ports the fix from #7820 to the latest DeepSpeed version.

It makes `Adam_Optimizer::IncrementStep` idempotent for repeated calls
at the same logical step and avoids unnecessary recomputation when the
step has not changed.

ZeRO-3/SuperOffload can invoke multiple subgroup updates within a single
logical step on a shared native optimizer object. The previous logic
mixed multiply and recompute paths, producing non-bit-identical
bias-correction metadata across subgroup calls.

This change aligns the step-transition logic in both the CPU and XPU
headers, clarifies first-step and non-sequential-step behavior, and
prevents unnecessary work on repeated same-step updates.

It also adds CPUAdam regression tests covering subgroup-style repeated
same-step updates through both `step_subgroup()` and `step()` with
parameter swapping.

Signed-off-by: st_bang <st.bang@dgist.ac.kr>
